import torch
import torch.nn as nn
import torch.nn.functional as F
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MultiHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, head_dim, num_heads):
        super().__init__()
        self.d_model = d_model
        self.d_k = head_dim * num_heads
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.Wq = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wk = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wv = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.out = nn.Linear(in_features=self.d_k, out_features=self.d_model)
    def forward(self, x):
        batch_size = x.shape[0]
        query = self.Wq(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        key = self.Wk(x)    # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        value = self.Wv(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # [B, num_heads, N, head_dim] * [B, num_heads, head_dim, N] -> [B, num_heads, N, N], TC = O(B*N^2*d_k)
        scores = torch.matmul(query, key.transpose(-1, -2)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)  # [B, num_heads, N, N]
        # [B, num_heads, N, N] * [B, num_heads, N, head_dim] -> [B, num_heads, N, head_dim] -> [B, N, d_k], TC = O(B*N^2*d_k)
        final = torch.matmul(attn, value).permute(0, 2, 1, 3).reshape(batch_size, -1, self.d_k)
        projout = self.out(final)  # [B, N, d_k] -> [B, N, d_model]
        return projout  # [B, N, d_model]
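
# Optional sanity check, a minimal sketch assuming PyTorch >= 2.0 for
# F.scaled_dot_product_attention: the score/softmax/matmul sequence in
# forward() should agree with the fused reference kernel on random inputs.
# The helper name _check_against_sdpa and the tiny shapes are hypothetical,
# chosen only for this check; they are not part of the model above.
def _check_against_sdpa(atol=1e-5):
    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))  # [B, num_heads, N, head_dim]
    ref = F.scaled_dot_product_attention(q, k, v)  # fused reference, same 1/sqrt(head_dim) scale
    scores = torch.matmul(q, k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
    ours = torch.matmul(F.softmax(scores, dim=-1), v)
    assert torch.allclose(ours, ref, atol=atol)

_check_against_sdpa()  # raises AssertionError on mismatch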
d_model = 512
# With num_heads=2 and head_dim=32, the model's internal d_k is 64, not d_model.
model = MultiHeadVanillaAttention(d_model=d_model, num_heads=2, head_dim=32).requires_grad_(False).eval().to(device)
n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)
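
# Quick shape sanity check on a tiny input (the dummy length 8 is an arbitrary
# choice for illustration) before committing to the 16384-token benchmark below.
assert model(torch.randn(1, 8, d_model, device=device)).shape == (1, 8, d_model)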
# Warm-up pass: the first CUDA call pays one-time initialization and kernel
# selection costs that should not be included in the timing.
with torch.no_grad():
    model(inp)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
k = 50
with torch.no_grad():
    for _ in range(k):
        out = model(inp)
if device.type == "cuda":
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
print(f"Time taken: {(time.time() - start) / k:.3f} s per forward pass")